# -*- coding: utf-8 -*-
### Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/
### Modified by Lesly Miculicich <lmiculicich@idiap.ch>
# 
# This file is part of HAN-NMT.
# 
# HAN-NMT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
# 
# HAN-NMT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with HAN-NMT. If not, see <http://www.gnu.org/licenses/>.

import os, sys
from collections import Counter, defaultdict, OrderedDict
from itertools import count

import torch
import torchtext.data
import torchtext.vocab
import numpy as np

from onmt.io.DatasetBase import UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD
from onmt.io.TextDataset import TextDataset
from onmt.io.ImageDataset import ImageDataset
from onmt.io.AudioDataset import AudioDataset


def _getstate(self):
	return dict(self.__dict__, stoi=dict(self.stoi))


def _setstate(self, state):
	self.__dict__.update(state)
	self.stoi = defaultdict(lambda: 0, self.stoi)


# Patch Vocab pickling: stoi is stored as a plain dict (_getstate) and
# re-wrapped as a defaultdict defaulting to 0 on load (_setstate).
torchtext.vocab.Vocab.__getstate__, torchtext.vocab.Vocab.__setstate__ = (
	_getstate, _setstate)


def get_fields(data_type, n_src_features, n_tgt_features):
	"""
	Build the Field objects for a dataset of the given input type.

	Args:
		data_type: type of the source input. Options are [text|img|audio].
		n_src_features: the number of source features to
			create `torchtext.data.Field` for.
		n_tgt_features: the number of target features to
			create `torchtext.data.Field` for.

	Returns:
		A dictionary whose keys are strings and whose values are the
		corresponding Field objects, or None for an unrecognised
		`data_type`.
	"""
	if data_type == 'text':
		dataset_cls = TextDataset
	elif data_type == 'img':
		dataset_cls = ImageDataset
	elif data_type == 'audio':
		dataset_cls = AudioDataset
	else:
		# Unrecognised types are not rejected; the caller gets None.
		return None
	return dataset_cls.get_fields(n_src_features, n_tgt_features)


def load_fields_from_vocab(vocab, data_type="text"):
	"""
	Load Field objects from `vocab.pt` file.

	Args:
		vocab: a mapping, or an iterable of (field name, Vocab) pairs.
		data_type: type of the source input, passed to `get_fields`.

	Returns:
		The dict of Fields with each saved Vocab attached. Each Vocab's
		`stoi` is re-wrapped (in place) as a defaultdict defaulting to 0.
	"""
	vocab_by_name = dict(vocab)
	fields = get_fields(data_type,
						len(collect_features(vocab_by_name, 'src')),
						len(collect_features(vocab_by_name, 'tgt')))
	for name, field_vocab in vocab_by_name.items():
		# Hack. Can't pickle defaultdict :( -- restore it after loading.
		field_vocab.stoi = defaultdict(lambda: 0, field_vocab.stoi)
		fields[name].vocab = field_vocab
	return fields


def save_fields_to_vocab(fields):
	"""
	Collect the Vocab objects held by `fields` for saving to `vocab.pt`.

	Args:
		fields (dict): mapping from field name to Field (or None).

	Returns:
		A list of ``(name, vocab)`` pairs, one for every non-None field that
		has a ``vocab`` attribute in its instance ``__dict__``.
	"""
	vocab = []
	for k, f in fields.items():
		if f is not None and 'vocab' in f.__dict__:
			# Deliberately do NOT replace f.vocab.stoi with a plain dict here:
			# that mutated the live field, so later lookups of unknown tokens
			# raised KeyError instead of using the defaultdict fallback.
			# Pickling is handled by the module-level Vocab.__getstate__ patch
			# (_getstate), which already stores a plain-dict copy of stoi.
			vocab.append((k, f.vocab))
	return vocab


def merge_vocabs(vocabs, vocab_size=None):
	"""
	Merge individual vocabularies (assumed to be generated from disjoint
	documents) into a larger vocabulary.

	Token frequencies are added with ``Counter`` addition, and the result is
	rebuilt with the UNK/PAD/BOS/EOS specials.

	Args:
		vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
		vocab_size: `int` the final vocabulary size. `None` for no limit.
	Return:
		`torchtext.vocab.Vocab`
	"""
	combined = Counter()
	for v in vocabs:
		combined = combined + v.freqs
	special_tokens = [UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD]
	return torchtext.vocab.Vocab(combined, specials=special_tokens,
								 max_size=vocab_size)


def get_num_features(data_type, corpus_file, side):
	"""
	Count the features on one side of a corpus.

	Args:
		data_type (str): type of the source input.
			Options are [text|img|audio].
		corpus_file (str): file path to get the features.
		side (str): for source or for target.

	Returns:
		number of features on `side`, or None for an unrecognised
		`data_type`.
	"""
	assert side in ["src", "tgt"]

	if data_type == 'text':
		dataset_cls = TextDataset
	elif data_type == 'img':
		dataset_cls = ImageDataset
	elif data_type == 'audio':
		dataset_cls = AudioDataset
	else:
		return None
	return dataset_cls.get_num_features(corpus_file, side)


def make_features(batch, side, data_type='text'):
	"""
	Assemble the input tensor for one side of a batch.

	Args:
		batch (Variable): a batch of source or target data.
		side (str): for source or for target.
		data_type (str): type of the source input.
			Options are [text|img|audio].
	Returns:
		For text: the `side` data and its `<side>_feat_<j>` feature
		tensors stacked along a new third dimension, in increasing `j`
		order (each level is assumed to be (len x batch)).
		For other data types: only the `side` data tensor.
	"""
	assert side in ['src', 'tgt']
	data = batch.__dict__[side]
	if isinstance(data, tuple):
		# Presumably (data, lengths) when the field includes lengths;
		# only the data tensor is used here.
		data = data[0]

	feat_start = side + "_feat_"
	# Sort by (length, name) so feature 10 follows feature 9; a plain
	# lexicographic sort would place "_feat_10" before "_feat_2".
	keys = sorted((k for k in batch.__dict__ if k.startswith(feat_start)),
				  key=lambda k: (len(k), k))
	levels = [data] + [batch.__dict__[k] for k in keys]

	if data_type == 'text':
		return torch.cat([level.unsqueeze(2) for level in levels], 2)
	return levels[0]


def collect_features(fields, side="src"):
	"""
	Return the consecutive `<side>_feat_<j>` keys present in `fields`.

	Keys are gathered for j = 0, 1, 2, ... and the scan stops at the first
	missing index, so later non-contiguous keys are ignored.
	"""
	assert side in ["src", "tgt"]
	names = []
	j = 0
	while side + "_feat_" + str(j) in fields:
		names.append(side + "_feat_" + str(j))
		j += 1
	return names


def collect_feature_vocabs(fields, side):
	"""
	Return the Vocab of each consecutive `<side>_feat_<j>` field.

	Indices are scanned from 0 upward and the scan stops at the first
	missing key.
	"""
	assert side in ['src', 'tgt']
	vocabs = []
	j = 0
	while True:
		name = '%s_feat_%d' % (side, j)
		if name not in fields:
			return vocabs
		vocabs.append(fields[name].vocab)
		j += 1


def build_dataset(fields, data_type, src_path, tgt_path, doc_path=None, src_dir=None,
				  src_seq_length=0, tgt_seq_length=0,
				  src_seq_length_trunc=0, tgt_seq_length_trunc=0,
				  dynamic_dict=True, sample_rate=0,
				  window_size=0, window_stride=0, window=None,
				  normalize_audio=True, use_filter_pred=True):
	"""
	Build a Text/Image/Audio dataset from source and target corpus files.

	Args:
		fields: dict of Field objects passed to the dataset constructor.
		data_type (str): type of the source input. One of [text|img|audio].
		src_path, tgt_path: source / target corpus files (tgt is always text).
		doc_path: file with one integer per line, passed to TextDataset as
			its document index list (presumably the example indices where
			documents start -- see DocumentIterator). Blank lines are
			skipped. Required for `text`; not used by `img`/`audio`.
		Remaining arguments are forwarded to the example readers and to the
		dataset constructors.

	Returns:
		The constructed dataset.

	Raises:
		ValueError: if `data_type` is unknown, or `doc_path` is missing for
			text data.
	"""
	if data_type not in ('text', 'img', 'audio'):
		raise ValueError("Unknown data_type %r; expected text, img or audio."
						 % (data_type,))
	if data_type == 'text' and doc_path is None:
		raise ValueError("doc_path is required when data_type is 'text'.")

	# Build src/tgt examples iterator from corpus files, also extract
	# number of features.
	src_examples_iter, num_src_feats = \
		_make_examples_nfeats_tpl(data_type, src_path, src_dir,
								  src_seq_length_trunc, sample_rate,
								  window_size, window_stride,
								  window, normalize_audio)

	# For all data types, the tgt side corpus is in form of text.
	tgt_examples_iter, num_tgt_feats = \
		TextDataset.make_text_examples_nfeats_tpl(
			tgt_path, tgt_seq_length_trunc, "tgt")

	doc_index = None
	if doc_path is not None:
		# `with` closes the file; blank lines (e.g. a trailing empty line)
		# are skipped instead of crashing in int().
		with open(doc_path) as doc_file:
			doc_index = [int(line) for line in doc_file if line.strip()]

	if data_type == 'text':
		dataset = TextDataset(fields, src_examples_iter, tgt_examples_iter, doc_index,
							  num_src_feats, num_tgt_feats,
							  src_seq_length=src_seq_length,
							  tgt_seq_length=tgt_seq_length,
							  dynamic_dict=dynamic_dict,
							  use_filter_pred=use_filter_pred)

	elif data_type == 'img':
		dataset = ImageDataset(fields, src_examples_iter, tgt_examples_iter,
							   num_src_feats, num_tgt_feats,
							   tgt_seq_length=tgt_seq_length,
							   use_filter_pred=use_filter_pred)

	else:  # 'audio' -- validated above
		dataset = AudioDataset(fields, src_examples_iter, tgt_examples_iter,
							   num_src_feats, num_tgt_feats,
							   tgt_seq_length=tgt_seq_length,
							   sample_rate=sample_rate,
							   window_size=window_size,
							   window_stride=window_stride,
							   window=window,
							   normalize_audio=normalize_audio,
							   use_filter_pred=use_filter_pred)

	return dataset


def _build_field_vocab(field, counter, **kwargs):
	specials = list(OrderedDict.fromkeys(
		tok for tok in [field.unk_token, field.pad_token, field.init_token,
						field.eos_token]
		if tok is not None))
	field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)


def _load_vocab_words(vocab_path):
	"""
	Read a vocabulary file and return the set of tokens it lists.

	Each non-blank line contributes its first whitespace-separated column;
	blank lines are skipped instead of raising IndexError.
	"""
	words = set()
	with open(vocab_path) as f:
		for line in f:
			columns = line.split()
			if columns:
				words.add(columns[0])
	return words


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
				src_vocab_path, src_vocab_size, src_words_min_frequency,
				tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
	"""
	Args:
		train_dataset_files: a list of train dataset pt file.
		fields (dict): fields to build vocab for.
		data_type: "text", "img" or "audio"?
		share_vocab(bool): share source and target vocabulary?
		src_vocab_path(string): Path to src vocabulary file. When non-empty,
				only source tokens listed in it are counted.
		src_vocab_size(int): size of the source vocabulary.
		src_words_min_frequency(int): the minimum frequency needed to
				include a source word in the vocabulary.
		tgt_vocab_path(string): Path to tgt vocabulary file. When non-empty,
				only target tokens listed in it are counted.
		tgt_vocab_size(int): size of the target vocabulary.
		tgt_words_min_frequency(int): the minimum frequency needed to
				include a target word in the vocabulary.

	Returns:
		Dict of Fields

	Raises:
		ValueError: if `train_dataset_files` yields no files.
	"""
	counter = {}
	for k in fields:
		counter[k] = Counter()

	# Optional token whitelists loaded from vocabulary files.
	src_vocab = None
	if src_vocab_path:
		print('Loading source vocab from %s' % src_vocab_path)
		assert os.path.exists(src_vocab_path), \
			'src vocab %s not found!' % src_vocab_path
		# Read the source file (this used to read tgt_vocab_path by mistake).
		src_vocab = _load_vocab_words(src_vocab_path)

	tgt_vocab = None
	if tgt_vocab_path:
		print('Loading target vocab from %s' % tgt_vocab_path)
		assert os.path.exists(tgt_vocab_path), \
			'tgt vocab %s not found!' % tgt_vocab_path
		tgt_vocab = _load_vocab_words(tgt_vocab_path)

	dataset = None
	for path in train_dataset_files:
		dataset = torch.load(path)
		print(" * reloading %s." % path)
		for ex in dataset.examples:
			for k in fields:
				val = getattr(ex, k, None)
				if val is None:
					# Nothing to count (Counter.update(None) is a no-op);
					# skipping also avoids iterating None in the filters.
					continue
				if not fields[k].sequential:
					val = [val]
				elif k == 'src' and src_vocab:
					val = [item for item in val if item in src_vocab]
				elif k == 'tgt' and tgt_vocab:
					val = [item for item in val if item in tgt_vocab]
				counter[k].update(val)

	if dataset is None:
		raise ValueError("build_vocab needs at least one training dataset file.")

	_build_field_vocab(fields["tgt"], counter["tgt"],
					   max_size=tgt_vocab_size,
					   min_freq=tgt_words_min_frequency)
	print(" * tgt vocab size: %d." % len(fields["tgt"].vocab))

	# All datasets have same num of n_tgt_features,
	# getting the last one is OK.
	for j in range(dataset.n_tgt_feats):
		key = "tgt_feat_" + str(j)
		_build_field_vocab(fields[key], counter[key])
		print(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

	if data_type == 'text':
		_build_field_vocab(fields["src"], counter["src"],
						   max_size=src_vocab_size,
						   min_freq=src_words_min_frequency)
		print(" * src vocab size: %d." % len(fields["src"].vocab))

		# All datasets have same num of n_src_features,
		# getting the last one is OK.
		for j in range(dataset.n_src_feats):
			key = "src_feat_" + str(j)
			_build_field_vocab(fields[key], counter[key])
			print(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

		# Merge the input and output vocabularies.
		if share_vocab:
			# `tgt_vocab_size` is ignored when sharing vocabularies
			print(" * merging src and tgt vocab...")
			merged_vocab = merge_vocabs(
				[fields["src"].vocab, fields["tgt"].vocab],
				vocab_size=src_vocab_size)
			fields["src"].vocab = merged_vocab
			fields["tgt"].vocab = merged_vocab

	return fields


def _make_examples_nfeats_tpl(data_type, src_path, src_dir,
							  src_seq_length_trunc, sample_rate,
							  window_size, window_stride,
							  window, normalize_audio):
	"""
	Process the corpus into (example_dict iterator, num_feats) tuple
	on source side for different 'data_type'.
	"""

	if data_type == 'text':
		src_examples_iter, num_src_feats = \
			TextDataset.make_text_examples_nfeats_tpl(
				src_path, src_seq_length_trunc, "src")

	elif data_type == 'img':
		src_examples_iter, num_src_feats = \
			ImageDataset.make_image_examples_nfeats_tpl(
				src_path, src_dir)

	elif data_type == 'audio':
		src_examples_iter, num_src_feats = \
			AudioDataset.make_audio_examples_nfeats_tpl(
				src_path, src_dir, sample_rate,
				window_size, window_stride, window,
				normalize_audio)

	return src_examples_iter, num_src_feats


class OrderedIterator(torchtext.data.Iterator):
	"""
	torchtext Iterator with custom batch construction.

	Training: examples are read in chunks of `batch_size * 100` (as grouped
	by torchtext.data.batch). Each chunk is sorted by `sort_key` and cut
	into batches, and the batches of each chunk are shuffled. Batches are
	produced lazily.
	Evaluation: batches follow dataset order, and each batch is sorted
	internally by `sort_key`.
	"""

	def create_batches(self):
		if not self.train:
			eval_batches = torchtext.data.batch(
				self.data(), self.batch_size, self.batch_size_fn)
			self.batches = [sorted(minibatch, key=self.sort_key)
							for minibatch in eval_batches]
			return

		def pooled(examples, shuffle):
			# Sort inside each large chunk so a batch groups examples with
			# nearby sort keys, then shuffle that chunk's batches.
			for chunk in torchtext.data.batch(examples, self.batch_size * 100):
				ordered = sorted(chunk, key=self.sort_key)
				chunk_batches = list(torchtext.data.batch(
					ordered, self.batch_size, self.batch_size_fn))
				for minibatch in shuffle(chunk_batches):
					yield minibatch

		self.batches = pooled(self.data(), self.random_shuffler)

class DocumentIterator(torchtext.data.Iterator):
	"""
	Iterator that batches examples while tracking document boundaries.

	An example starts a new document when its `indices` value appears in
	`dataset.doc_index`.

	Training: whole documents are shuffled as units, and the resulting
	example stream is cut into batches with `torchtext.data.batch`, so a
	batch may span documents.
	Evaluation: every document is split into consecutive chunks of at most
	`batch_size` examples, so a batch never spans two documents.

	Iterating yields `(batch, starts)` pairs, where `starts` lists the
	positions inside the batch of the examples that begin a document.
	"""

	def __init__(self, dataset, batch_size, device=None,
				 batch_size_fn=None, train=True, shuffle=None,
				 sort_within_batch=None):
		# The base iterator is created without sorting, shuffling or
		# repeating, so example order follows the documents. Document-level
		# shuffling is done in document_shuffler(); `shuffle` is accepted
		# but ignored.
		super(DocumentIterator, self).__init__(
			dataset, batch_size, device=device,
			batch_size_fn=batch_size_fn, train=train,
			repeat=False, shuffle=False, sort=False,
			sort_within_batch=sort_within_batch)
		# doc_index[i] is True when example i starts a document;
		# doc_range holds the (start, end) span of every document.
		self.doc_index, self.doc_range = self.get_context_index(self.data())
		# Document-start mask aligned with the current epoch's example
		# order; set by create_batches().
		self.indx = None

	def document_shuffler(self):
		"""
		Shuffle whole documents.

		Returns:
			(examples, mask): the examples of all documents in shuffled
			document order, and a numpy boolean array marking which of them
			start a document.
		"""
		order = self.random_shuffler(range(len(self.doc_range)))
		docs, indx = [], []
		for i in order:
			start, end = self.doc_range[i]
			docs.extend(self.dataset[start:end])
			indx.extend(self.doc_index[start:end])

		assert len(docs) == len(self.dataset), "Error in document indexes"
		assert len(indx) == len(self.doc_index), "Error in document indexes"

		return docs, np.array(indx)

	def create_batches(self):
		"""Build this epoch's batches and the matching document-start mask."""
		if self.train:
			data, indx = self.document_shuffler()
			self.batches = torchtext.data.batch(data, self.batch_size, self.batch_size_fn)
			self.indx = indx
		else:
			self.batches = self.batch_eval()
			self.indx = np.array(self.doc_index)

	def get_context_index(self, batch):
		"""
		Locate document boundaries in `batch` (the full, ordered example list).

		Returns:
			(d_index, d_range): `d_index[i]` is True when example i starts a
			document, and `d_range` lists the (start, end) span of each
			document. Both are empty when `batch` is empty (previously a bogus
			(0, 1) range was produced).
		"""
		# Set membership is O(1); testing against a list was O(n) per example.
		doc_starts = set(self.dataset.doc_index)
		d_index = [m.indices in doc_starts for m in batch]
		d_range, prev_i = [], 0
		for i, is_start in enumerate(d_index):
			if is_start:
				if prev_i != i:
					d_range.append((prev_i, i))
				prev_i = i
		if d_index:
			# Close the last document at the end of the data.
			d_range.append((prev_i, len(d_index)))
		return d_index, d_range

	def __iter__(self):
		"""
		Yield `(torchtext.data.Batch, starts)` pairs for one epoch.

		`starts` are the positions within the batch of the examples that
		start a document.
		"""
		while True:
			self.init_epoch()
			offset = 0
			for idx, minibatch in enumerate(self.batches):
				# Advance the mask offset even for skipped batches; otherwise
				# the mask is misaligned after resuming mid-epoch.
				start, offset = offset, offset + len(minibatch)
				# fast-forward if loaded from state
				if self._iterations_this_epoch > idx:
					continue
				self.iterations += 1
				self._iterations_this_epoch += 1
				starts = np.where(self.indx[start:offset])[0].tolist()
				yield torchtext.data.Batch(minibatch, self.dataset, self.device, self.train), starts
			if not self.repeat:
				# `return`, not `raise StopIteration`: since PEP 479 (Python
				# 3.7) StopIteration raised inside a generator becomes a
				# RuntimeError.
				return

	def batch_eval(self):
		"""Yield each document as consecutive chunks of at most batch_size examples."""
		for start, end in self.doc_range:
			for chunk_start in range(start, end, self.batch_size):
				# Clamp to `end` so a chunk never spills into the next document.
				yield self.dataset[chunk_start:min(chunk_start + self.batch_size, end)]
	
